
package org.apache.seatunnel.connectors.seatunnel.gaussdbcassandra.sink;

import com.datastax.oss.driver.api.core.CqlSession;
import com.datastax.oss.driver.api.core.cql.*;
import com.datastax.oss.driver.api.core.type.DataType;
import lombok.extern.slf4j.Slf4j;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.common.utils.ExceptionUtils;
import org.apache.seatunnel.connectors.seatunnel.common.sink.AbstractSinkWriter;
import org.apache.seatunnel.connectors.seatunnel.gaussdbcassandra.Client.GaussDBClient;
import org.apache.seatunnel.connectors.seatunnel.gaussdbcassandra.config.GaussDBCassandraParameters;
import org.apache.seatunnel.connectors.seatunnel.gaussdbcassandra.exception.GaussDBCassandraErrorCode;
import org.apache.seatunnel.connectors.seatunnel.gaussdbcassandra.exception.GaussDBConnectorException;
import org.apache.seatunnel.connectors.seatunnel.gaussdbcassandra.serialize.TypeConvertUtil;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Sink writer that inserts {@link SeaTunnelRow}s into a GaussDB (Cassandra-compatible) table
 * through a single prepared {@code INSERT} statement.
 *
 * <p>Two write modes are supported, selected by {@code gaussDBParameters.getAsyncWrite()}:
 * <ul>
 *   <li>async: each row is sent immediately with {@code executeAsync}; {@link #flush()} waits
 *       for the outstanding requests and logs (but does not rethrow) failures;</li>
 *   <li>sync: rows are buffered and sent as one batch of type {@code getBatchType()}; if the
 *       batch fails the rows are retried one by one.</li>
 * </ul>
 * A flush happens every {@code getBatchSize()} rows and once more on {@link #close()}.
 *
 * <p>Not thread-safe; SeaTunnel drives each writer instance from a single thread.
 */
@Slf4j
public class GaussDBCassandraSinkWriter extends AbstractSinkWriter<SeaTunnelRow, Void> {

    private final GaussDBCassandraParameters gaussDBParameters;
    private final SeaTunnelRowType seaTunnelRowType;
    private final ColumnDefinitions tableSchema;
    private final CqlSession session;
    /**
     * Empty batch used as a template. Driver statements are immutable, so {@code addAll} returns
     * a new batch and this instance always stays empty.
     */
    private final BatchStatement batchStatement;
    /** Rows buffered for the next synchronous batch. */
    private final List<BoundStatement> boundStatementList;
    /** Asynchronous writes issued since the last flush. */
    private final List<CompletionStage<AsyncResultSet>> completionStages;
    private final PreparedStatement preparedStatement;
    /** Number of rows written since the last flush. */
    private final AtomicInteger counter = new AtomicInteger(0);

    /**
     * Opens a CQL session for the configured host/keyspace and prepares the INSERT statement.
     *
     * @param gaussDBParameters connection, table, field list and batching options
     * @param seaTunnelRowType  schema of incoming rows, used to look up field values by name
     * @param tableSchema       column definitions of the target table, matched to fields by index
     */
    public GaussDBCassandraSinkWriter(GaussDBCassandraParameters gaussDBParameters, SeaTunnelRowType seaTunnelRowType,
                                      ColumnDefinitions tableSchema) {
        this.gaussDBParameters = gaussDBParameters;
        this.seaTunnelRowType = seaTunnelRowType;
        this.tableSchema = tableSchema;
        this.session = GaussDBClient.getCqlSessionBuilder(gaussDBParameters.getHost(), gaussDBParameters.getKeyspace(),
                        gaussDBParameters.getUsername(), gaussDBParameters.getPassword(), gaussDBParameters.getDatacenter())
                .build();
        this.batchStatement = BatchStatement.builder(gaussDBParameters.getBatchType()).build();
        this.boundStatementList = new ArrayList<>();
        this.completionStages = new ArrayList<>();
        try {
            this.preparedStatement = session.prepare(initPrepareCQL());
        } catch (RuntimeException e) {
            // Don't leak the freshly opened session when the INSERT cannot be prepared.
            try {
                session.close();
            } catch (RuntimeException closeError) {
                e.addSuppressed(closeError);
            }
            throw e;
        }
    }

    /**
     * Writes one row. The row is either sent asynchronously right away or buffered for the next
     * batch; every {@code getBatchSize()} rows the pending work is flushed.
     *
     * @param row the data need be written.
     * @throws IOException declared by the sink writer contract; not thrown by this implementation
     */
    @Override
    public void write(SeaTunnelRow row) throws IOException {
        BoundStatement boundStatement = this.preparedStatement.bind();
        addIntoBatch(row, boundStatement);
        // incrementAndGet so that exactly batchSize rows are pending per flush
        // (getAndIncrement would flush only after batchSize + 1 rows).
        if (counter.incrementAndGet() >= gaussDBParameters.getBatchSize()) {
            flush();
            counter.set(0);
        }
    }

    /**
     * Flushes any pending rows and closes the CQL session. The session is closed even if the
     * final flush fails.
     *
     * @throws IOException declared by the sink writer contract; not thrown by this implementation
     */
    @Override
    public void close() throws IOException {
        try {
            flush();
        } finally {
            if (this.session != null) {
                try {
                    this.session.close();
                } catch (RuntimeException e) {
                    throw new GaussDBConnectorException(GaussDBCassandraErrorCode.CLOSE_CQL_SESSION_FAILED, e);
                }
            }
        }
    }

    /**
     * write the cached data into Cassandra database in batches.
     */
    private void flush() {
        if (gaussDBParameters.getAsyncWrite()) {
            awaitAsyncWrites();
        } else {
            executeBufferedBatch();
        }
    }

    /**
     * Blocks until every asynchronous write issued since the last flush has finished. Failures
     * are logged and not rethrown, so async mode stays best-effort.
     */
    private void awaitAsyncWrites() {
        try {
            for (CompletionStage<AsyncResultSet> stage : completionStages) {
                // handle() turns a failed write into a normal completion, so join() waits for the
                // request to finish without throwing.
                stage.toCompletableFuture()
                        .handle((resultSet, error) -> {
                            if (error != null) {
                                log.error(ExceptionUtils.getMessage(error));
                            }
                            return null;
                        })
                        .join();
            }
        } finally {
            completionStages.clear();
        }
    }

    /**
     * Sends the buffered rows as one batch. If the batch fails, the rows are retried one at a
     * time; an error during the retry propagates to the caller. The buffer is cleared in every case.
     */
    private void executeBufferedBatch() {
        if (boundStatementList.isEmpty()) {
            // Nothing buffered (e.g. close() right after a flush): don't send an empty batch.
            return;
        }
        try {
            // addAll returns a new immutable batch; the template field stays empty.
            this.session.execute(this.batchStatement.addAll(boundStatementList));
        } catch (RuntimeException e) {
            // Driver failures are unchecked exceptions, so they must be caught here for the
            // one-by-one fallback to ever run.
            log.error("Batch insert error,Try inserting one by one!", e);
            for (BoundStatement statement : boundStatementList) {
                this.session.execute(statement);
            }
        } finally {
            this.boundStatementList.clear();
        }
    }

    /**
     * generate precompiled statements of Cassandra database.
     *
     * @return an {@code INSERT INTO <table> (f1,f2,...) VALUES (?,?,...)} statement with one bind
     *         marker per configured field
     */
    private String initPrepareCQL() {
        String[] placeholder = new String[gaussDBParameters.getFields().size()];
        Arrays.fill(placeholder, "?");
        return String.format("INSERT INTO %s (%s) VALUES (%s)", gaussDBParameters.getTable(),
                String.join(",", gaussDBParameters.getFields()), String.join(",", placeholder));
    }

    /**
     * Binds the configured fields of {@code row} into {@code boundStatement}. The result is then
     * either sent asynchronously or appended to the synchronous batch buffer.
     *
     * @param row            source row; values are looked up by configured field name
     * @param boundStatement a fresh binding of the prepared INSERT
     * @throws GaussDBConnectorException with {@code ADD_BATCH_DATA_FAILED} if binding fails with a
     *                                   {@code GaussDBConnectorException}
     */
    private void addIntoBatch(SeaTunnelRow row, BoundStatement boundStatement) {
        try {
            for (int i = 0; i < gaussDBParameters.getFields().size(); i++) {
                String fieldName = gaussDBParameters.getFields().get(i);
                // NOTE(review): assumes tableSchema columns are in the same order as the configured fields.
                DataType dataType = tableSchema.get(i).getType();
                Object fieldValue = row.getField(seaTunnelRowType.indexOf(fieldName));
                // Bound statements are immutable, so keep the instance returned by each set.
                boundStatement = TypeConvertUtil.reconvertAndInject(boundStatement, i, dataType, fieldValue);
            }
            if (gaussDBParameters.getAsyncWrite()) {
                completionStages.add(session.executeAsync(boundStatement));
            } else {
                boundStatementList.add(boundStatement);
            }
        } catch (GaussDBConnectorException e) {
            throw new GaussDBConnectorException(GaussDBCassandraErrorCode.ADD_BATCH_DATA_FAILED, e);
        }
    }
}
